/*
 * Copyright (c) Hisilicon Technologies Co., Ltd. 2022-2022. All rights reserved.
 * Description: 轻量化配置文件解析模块
 */
#include "quantize_cfg_parser.h"

#include <algorithm>
#include <fstream>
#include <map>
#include <memory>
#include <vector>

#include <nlohmann/json.hpp>

#include "securec.h"
#include "graph/types.h"
#include "quantize_types.h"
#include "base/common/file_util/file_util.h"
#include "framework/infra/log/log.h"

using namespace std;
using namespace ge;

namespace hiai {
using json = nlohmann::json;

namespace {
#define JSON_GET_IMP(FuncName, ValueType, Type) \
    bool JsonGet##FuncName(const json& j, const string& key, Type& value) \
    { \
        auto iter = j.find(key); \
        if (iter == j.end()) { \
            return false; \
        } \
        if (!iter.value().ValueType()) { \
            FMK_LOGE("invalid value type for key [%s].", key.c_str()); \
            return false; \
        } \
        value = iter.value().get<Type>(); \
        return true; \
    }

JSON_GET_IMP(Int, is_number_integer, int32_t)
JSON_GET_IMP(String, is_string, std::string)
JSON_GET_IMP(Bool, is_boolean, bool)

// Removes the XOR mask from `buffer` in place.
// Byte i is XOR-ed with the character at position (i mod key length) of the
// fixed key "ddkv320". Nothing is touched when `size` is 0.
void UnmaskBuffer(uint8_t* buffer, size_t size)
{
    const std::string maskKey("ddkv320");
    const size_t keyLen = maskKey.size();

    size_t keyPos = 0;
    for (size_t offset = 0; offset != size; ++offset) {
        buffer[offset] ^= static_cast<uint8_t>(maskKey[keyPos]);
        // Advance through the key cyclically instead of computing offset % keyLen.
        ++keyPos;
        if (keyPos == keyLen) {
            keyPos = 0;
        }
    }
}

// Sets params.useWeightName from the optional integer key "useWeightName" of `j`.
// The flag becomes true only when the key holds the integer 1; it is false when
// the key is missing, is not an integer, or holds any other value.
void GenerateUseWeightNameFlag(const json& j, ModelLightWeightParams& params)
{
    const std::string flagKey = "useWeightName";

    int32_t rawFlag = 0;
    const bool found = JsonGetInt(j, flagKey, rawFlag);
    params.useWeightName = found && (rawFlag == 1);
}

bool GenerateFloatArrayConfig(const json& quantData, const std::string& key, std::vector<float>& array)
{
    vector<float> tmp;
    auto iter = quantData.find(key);
    if (iter == quantData.end()) {
        FMK_LOGE("can't find key [%s].", key.c_str());
        return false;
    }

    if (iter.value().is_number_float()) {
        tmp.push_back(iter.value().get<float>());
    } else if (iter.value().is_array()) {
        for (float f : iter.value()) {
            tmp.push_back(f);
        }
    } else {
        FMK_LOGE("invalid type for key [%s].", key.c_str());
        return false;
    }

    array.swap(tmp);
    return true;
}

// Maps the textual data type name used in the quantize config to the
// corresponding ge::DataType. The comparison is exact (case-sensitive).
// "UINT16" and "INT16" are only recognised when
// AI_SUPPORT_INT16_INT8_QUANTIZE is defined.
// Returns DT_UNDEFINED for any name that is not in the table.
ge::DataType String2DataType(const std::string& strType)
{
    struct TypeNameEntry {
        const char* name;
        ge::DataType type;
    };
    static const TypeNameEntry TYPE_NAME_TABLE[] = {
        { "FLOAT", DT_FLOAT },
        { "FP16", DT_FLOAT16 },
#ifdef AI_SUPPORT_INT16_INT8_QUANTIZE
        { "UINT16", DT_UINT16 },
        { "INT16", DT_INT16 },
#endif
        { "UINT8", DT_UINT8 },
        { "INT8", DT_INT8 },
        { "INT4", DT_INT4 },
        { "INT2", DT_2BIT },
        { "INT32", DT_INT32 },
    };

    for (const TypeNameEntry& entry : TYPE_NAME_TABLE) {
        if (strType == entry.name) {
            return entry.type;
        }
    }
    return DT_UNDEFINED;
}

bool GetQuantDataType(const json& j, const string& inputValue, ge::DataType& dataType)
{
    string input;
    if (!JsonGetString(j, inputValue, input)) {
        FMK_LOGE("Json data can't find key:%s.", inputValue.c_str());
        return false;
    }
    dataType = String2DataType(input);
    return true;
}

bool GenerateInputConfig(const json& quantData, QuantizeConfig& quantizeConfig)
{
    const std::string KEY_QUANTIZE_INPUT = "input";
    const std::string KEY_QUANTIZE_INPUT_SCALE = "inputScale";
    const std::string KEY_QUANTIZE_INPUT_OFFSET = "inputOffset";

    ge::DataType inputDataType;
    if (!GetQuantDataType(quantData, KEY_QUANTIZE_INPUT, inputDataType)) {
        return false;
    }
    if (inputDataType == DT_UNDEFINED) {
        FMK_LOGE("Invalid data type for [%s].", KEY_QUANTIZE_INPUT.c_str());
        return false;
    }
    inputDataType = (inputDataType == DT_INT8) ? DT_UINT8 : inputDataType;
    quantizeConfig.inputDataType = inputDataType;

    return GenerateFloatArrayConfig(quantData, KEY_QUANTIZE_INPUT_SCALE, quantizeConfig.inputScale) &&
        GenerateFloatArrayConfig(quantData, KEY_QUANTIZE_INPUT_OFFSET, quantizeConfig.inputOffset);
}

bool GenerateWeightConfig(const json& quantData, QuantizeConfig& quantizeConfig)
{
    const std::string KEY_QUANTIZE_WEIGHT = "weight";
    const std::string KEY_QUANTIZE_WEIGHT_SCALE = "weightScale";
    const std::string KEY_QUANTIZE_WEIGHT_OFFSET = "weightOffset";

    ge::DataType weightDataType;
    if (!GetQuantDataType(quantData, KEY_QUANTIZE_WEIGHT, weightDataType)) {
        return false;
    }
    if (weightDataType == DT_UNDEFINED) {
        FMK_LOGE("Invalid data type for [%s].", KEY_QUANTIZE_WEIGHT.c_str());
        return false;
    }
    quantizeConfig.weightDataType = weightDataType;

    return GenerateFloatArrayConfig(quantData, KEY_QUANTIZE_WEIGHT_SCALE, quantizeConfig.weightScale) &&
        GenerateFloatArrayConfig(quantData, KEY_QUANTIZE_WEIGHT_OFFSET, quantizeConfig.weightOffset);
}

void GenerateQuantizeExtConfig(const json& quantData, string& quantInfoExt, bool& hasQuantInfoExt)
{
    const std::string KEY_QUANTIZE_INFO_EXT = "quantInfoExt";

    string value = "";
    if (JsonGetString(quantData, KEY_QUANTIZE_INFO_EXT, value) && value != "") {
        quantInfoExt = value;
        hasQuantInfoExt = true;
    }
}

// Builds a QuantizeConfig from one operator's "quant" object.
// The input section is parsed first, then the weight section; parsing stops at
// the first failure and false is returned. The optional extension info is only
// read when both mandatory sections were parsed.
bool GenerateQuantizeConfig(const json& opData, QuantizeConfig& quantizeConfig)
{
    const bool mandatoryParsed =
        GenerateInputConfig(opData, quantizeConfig) && GenerateWeightConfig(opData, quantizeConfig);
    if (mandatoryParsed) {
        GenerateQuantizeExtConfig(opData, quantizeConfig.quantInfoExt, quantizeConfig.hasQuantInfoExt);
    }
    return mandatoryParsed;
}

bool GetQuantOperatorType(const json& opData, const std::string& key, OperatorType& operatorType)
{
    map<string, OperatorType> quantTypeActions { { "Quantize", OperatorType::QUANTIZE },
        { "DeQuantize", OperatorType::DEQUANTIZE }, { "ReQuantize", OperatorType::REQUANTIZE },
        { "AntiQuantize", OperatorType::ANTIQUANTIZE } };

    string quantType = "";
    if (!JsonGetString(opData, key, quantType)) {
        FMK_LOGE("Json data can't find key:%s.", key.c_str());
        return false;
    }
    auto it = quantTypeActions.find(quantType);
    if (it == quantTypeActions.end()) {
        FMK_LOGE("Quant type:%s is not supported.", quantType.c_str());
        return false;
    }
    operatorType = it->second;
    return true;
}

// Parses one V2 input/output entry of an operator and appends it to `quantParams`.
//
// Expected layout of `opData`:
//   "index"      : non-negative integer, stored in QuantizeParams::index
//   "QuantParam" : optional collection; each entry provides
//       "index"     : non-negative integer, the slot of this entry in
//                     QuantizeParams::operatorParams
//       "quantType" : operator type name (see GetQuantOperatorType)
//       "dataType"  : data type name (see GetQuantDataType)
//       "scale"     : float or array of numbers
//       "offset"    : float or array of numbers
//
// operatorParams is sized to the number of "QuantParam" entries and each entry
// is stored at its own "index", so every index must be smaller than that count.
// Returns false (after logging) on the first invalid field; `quantParams` is
// only modified on success.
bool GenerateV2QuantParams(const json& opData, std::vector<QuantizeParams>& quantParams)
{
    const std::string KEY_QUANTIZE_INDEX = "index";
    const std::string KEY_QUANTIZE_QUANTPARAM = "QuantParam";
    const std::string KEY_QUANTIZE_QUANTTYPE = "quantType";
    const std::string KEY_QUANTIZE_DATATYPE = "dataType";
    const std::string KEY_QUANTIZE_SCALE = "scale";
    const std::string KEY_QUANTIZE_OFFSET = "offset";

    QuantizeParams param;
    int32_t index = 0;
    if (!JsonGetInt(opData, KEY_QUANTIZE_INDEX, index) || index < 0) {
        FMK_LOGE("Get index fail.");
        return false;
    }
    param.index = static_cast<uint32_t>(index);

    auto iter = opData.find(KEY_QUANTIZE_QUANTPARAM);
    if (iter != opData.end()) {
        const size_t operCount = iter.value().size();
        param.operatorParams.resize(operCount);
        for (auto opIt = iter.value().begin(); opIt != iter.value().end(); ++opIt) {
            OperatorParam operParam;
            int32_t operIndex = 0;
            if (!JsonGetInt(*opIt, KEY_QUANTIZE_INDEX, operIndex) || operIndex < 0) {
                FMK_LOGE("Get operIndex fail.");
                return false;
            }
            // The index comes from the config file and is used as a position in
            // operatorParams below; reject values outside the resized vector to
            // avoid an out-of-bounds write.
            if (static_cast<size_t>(operIndex) >= operCount) {
                FMK_LOGE("operIndex [%d] is out of range.", operIndex);
                return false;
            }
            operParam.operIndex = static_cast<uint32_t>(operIndex);

            if (!GetQuantOperatorType(*opIt, KEY_QUANTIZE_QUANTTYPE, operParam.operType)) {
                FMK_LOGE("Get quantType fail.");
                return false;
            }

            if (!GetQuantDataType(*opIt, KEY_QUANTIZE_DATATYPE, operParam.dataType)) {
                FMK_LOGE("Get dataType fail.");
                return false;
            }

            if (!GenerateFloatArrayConfig(*opIt, KEY_QUANTIZE_SCALE, operParam.scale)) {
                FMK_LOGE("Get scale fail.");
                return false;
            }

            if (!GenerateFloatArrayConfig(*opIt, KEY_QUANTIZE_OFFSET, operParam.offset)) {
                FMK_LOGE("Get offset fail.");
                return false;
            }
            param.operatorParams[operParam.operIndex] = operParam;
        }
    }

    quantParams.push_back(std::move(param));
    return true;
}

bool GenerateQuantizeV2Config(const json& opData, QuantizeV2Config& quantizeConfig)
{
    const std::string KEY_QUANTIZE_INPUT = "input";
    const std::string KEY_QUANTIZE_OUTPUT = "output";
    const std::string KEY_QUANTIZE_ISONESIDEQUANTIZE = "isOneSideQuantize";

    auto iter = opData.find(KEY_QUANTIZE_INPUT);
    if (iter != opData.end()) {
        for (auto opIt = iter.value().begin(); opIt != iter.value().end(); opIt++) {
            if (!GenerateV2QuantParams(*opIt, quantizeConfig.inputQuantParams)) {
                return false;
            }
        }
    }

    iter = opData.find(KEY_QUANTIZE_OUTPUT);
    if (iter != opData.end()) {
        for (auto opIt = iter.value().begin(); opIt != iter.value().end(); opIt++) {
            if (!GenerateV2QuantParams(*opIt, quantizeConfig.outputQuantParams)) {
                return false;
            }
        }
    }

    if (!JsonGetBool(opData, KEY_QUANTIZE_ISONESIDEQUANTIZE, quantizeConfig.isOneSideQuantize)) {
        quantizeConfig.isOneSideQuantize = false;
    }

    GenerateQuantizeExtConfig(opData, quantizeConfig.quantInfoExt, quantizeConfig.hasQuantInfoExt);

    return true;
}

bool GenerateOpConfig(const json& opData, const string& name, map<string, QuantizeConfig>& quantizeConfigs)
{
    QuantizeConfig quantizeConfig;
    if (!GenerateQuantizeConfig(opData, quantizeConfig)) {
        FMK_LOGE("Generate quantizeConfig from Json failed, node name is %s.", name.c_str());
        return false;
    }

    auto result = quantizeConfigs.emplace(name, quantizeConfig);
    return result.second;
}

bool GenerateOpV2Config(const json& opData, const string& name, map<string, QuantizeV2Config>& quantizeConfigs)
{
    QuantizeV2Config quantizeConfig;
    if (!GenerateQuantizeV2Config(opData, quantizeConfig)) {
        FMK_LOGE("Generate quantizeConfig from Json failed, node name is %s.", name.c_str());
        return false;
    }

    auto result = quantizeConfigs.emplace(name, quantizeConfig);
    return result.second;
}

bool ParseOpQuantizeConfig(const json& opData, ModelLightWeightParams& params)
{
    const std::string KEY_COMPRESS_NAME = "name";
    const std::string KEY_QUANTIZE_QUANT = "quant";
    string name;
    if (!JsonGetString(opData, KEY_COMPRESS_NAME, name)) {
        FMK_LOGE("Json data can't find key:%s.", KEY_COMPRESS_NAME.c_str());
        return false;
    }

    auto iter = opData.find(KEY_QUANTIZE_QUANT);
    if (iter == opData.end()) {
        FMK_LOGE("can't find key [%s].", KEY_QUANTIZE_QUANT.c_str());
        return false;
    }

    bool ret = false;
    if (params.version == "") {
        ret = GenerateOpConfig(iter.value(), name, params.quantizeConfigs);
    } else {
        ret = GenerateOpV2Config(iter.value(), name, params.quantizeV2Configs);
    }
    if (!ret) {
        FMK_LOGE("Generate op config fail, op:%s, version:%s.", name.c_str(), params.version.c_str());
        return false;
    }

    return true;
}

void GetConfigVersion(const json& j, std::string& version)
{
    const std::string KEY_COMPRESS_VERSION = "Version";

    version = "";
    string value;
    if (JsonGetString(j, KEY_COMPRESS_VERSION, value)) {
        version = value;
    }
}

// Parses every operator entry found under the top-level key
// "ModelLightWeightParameter" of `j` into `params`.
// Returns hiai::FAILURE when the key is missing or when any entry fails to
// parse (parsing stops at the first failing entry); hiai::SUCCESS otherwise.
Status ParseOpQuantizeConfigs(const json& j, ModelLightWeightParams& params)
{
    const std::string opListKey = "ModelLightWeightParameter";

    const auto opListIt = j.find(opListKey);
    if (opListIt == j.end()) {
        FMK_LOGE("can't find key [%s].", opListKey.c_str());
        return hiai::FAILURE;
    }

    for (const json& opEntry : *opListIt) {
        if (ParseOpQuantizeConfig(opEntry, params)) {
            continue;
        }
        FMK_LOGE("parse op quantize config fail, version:%s.", params.version.c_str());
        return hiai::FAILURE;
    }
    return hiai::SUCCESS;
}

// Unmasks `buffer` in place, parses its first `size` bytes as JSON and fills
// `params` from the result.
// The JSON parser is called with exceptions disabled, so malformed input is
// reported as a discarded value and turned into hiai::FAILURE here.
// The "useWeightName" flag is only read when the config carries no version.
hiai::Status ParseQuantizeConfigs(uint8_t* buffer, size_t size, ModelLightWeightParams& params)
{
    UnmaskBuffer(buffer, size);

    const json root = json::parse(buffer, buffer + size, nullptr, false);
    if (root.is_discarded()) {
        FMK_LOGE("Convert stream to json failed, please check light tool version.");
        return hiai::FAILURE;
    }

    GetConfigVersion(root, params.version);
    const bool isLegacyConfig = params.version.empty();
    if (isLegacyConfig) {
        GenerateUseWeightNameFlag(root, params);
    }

    return ParseOpQuantizeConfigs(root, params);
}
} // namespace

// Loads the masked quantize config stored in `file` and parses it into `params`.
//
// Returns hiai::FAILURE when the path is empty, the file cannot be read, the
// file content is empty, or the content cannot be parsed; otherwise returns the
// result of ParseQuantizeConfigs.
hiai::Status QuantizeCfgParser::ParseConfigFromFile(const string& file, ModelLightWeightParams& params)
{
    if (file.empty()) {
        FMK_LOGE("File is null.");
        return hiai::FAILURE;
    }

    uint8_t* rawBuffer = nullptr;
    size_t size = 0;
    if (ReadFile(file.c_str(), rawBuffer, size) != 0) {
        FMK_LOGE("read file %s failed.", file.c_str());
        return hiai::FAILURE;
    }

    // The buffer handed back by ReadFile is released with delete[]; owning it
    // through unique_ptr releases it on every exit path, including when parsing
    // throws.
    const std::unique_ptr<uint8_t[]> buffer(rawBuffer);
    if (buffer == nullptr || size == 0) {
        FMK_LOGE("read file %s failed.", file.c_str());
        return hiai::FAILURE;
    }

    return ParseQuantizeConfigs(buffer.get(), size, params);
}

// Parses a masked quantize config held in memory into `params`.
// The data is copied into a private, NUL-terminated working buffer first, so
// the in-place unmasking done by ParseQuantizeConfigs does not modify the
// caller's buffer.
// Returns hiai::FAILURE when the buffer is null or empty, when the copy fails,
// or when parsing fails.
hiai::Status QuantizeCfgParser::ParseConfigFromBuffer(uint8_t* buffer, size_t size, ModelLightWeightParams& params)
{
    const bool isEmptyInput = (buffer == nullptr) || (size == 0);
    if (isEmptyInput) {
        FMK_LOGE("Quant config buffer is empty.");
        return hiai::FAILURE;
    }

    // One extra byte holds the terminating NUL.
    const size_t workSize = size + 1;
    vector<uint8_t> workBuffer(workSize);
    const errno_t copyRet = memcpy_s(workBuffer.data(), workSize, buffer, size);
    if (copyRet != EOK) {
        FMK_LOGE("memcpy_s fail, ret:%d.", copyRet);
        return hiai::FAILURE;
    }
    workBuffer[size] = '\0';

    return ParseQuantizeConfigs(workBuffer.data(), size, params);
}
} // namespace hiai
